{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using Paddle DALI plugin: using various readers\n",
    "\n",
    "### Overview\n",
    "\n",
    "This example shows how different readers can be used to interact with Paddle. It shows how flexible DALI is.\n",
    "\n",
    "The following readers are used in this example:\n",
    "\n",
    "- readers.mxnet\n",
    "- readers.caffe\n",
    "- readers.file\n",
    "- readers.tfrecord\n",
    "\n",
    "For details on how to use them, please see the other [examples](../../index.rst)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us start by defining some global constants.\n",
    "\n",
    "The `DALI_EXTRA_PATH` environment variable should point to the location where the data from the [DALI extra repository](https://github.com/NVIDIA/DALI_extra) has been downloaded. Please make sure that the proper release tag is checked out."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os.path\n",
    "import subprocess\n",
    "\n",
    "# Root of the DALI_extra data; raises KeyError if the variable is not set\n",
    "test_data_root = os.environ[\"DALI_EXTRA_PATH\"]\n",
    "\n",
    "# MXNet RecordIO (trailing slash is required: file names are appended\n",
    "# to this string directly)\n",
    "db_folder = os.path.join(test_data_root, \"db\", \"recordio/\")\n",
    "\n",
    "# Caffe LMDB\n",
    "lmdb_folder = os.path.join(test_data_root, \"db\", \"lmdb\")\n",
    "\n",
    "# image dir with plain jpeg files\n",
    "image_dir = \"../../data/images\"\n",
    "\n",
    "# TFRecord\n",
    "tfrecord = os.path.join(test_data_root, \"db\", \"tfrecord\", \"train\")\n",
    "tfrecord_idx = \"idx_files/train.idx\"\n",
    "tfrecord2idx_script = \"tfrecord2idx\"\n",
    "\n",
    "# `nvidia-smi -L` prints one \"GPU <index>: ...\" line per GPU. Count only\n",
    "# those lines rather than every newline: MIG-enabled systems print extra\n",
    "# indented \"MIG ...\" lines, and any other text in the output would\n",
    "# otherwise be counted as a GPU too.\n",
    "res = subprocess.run([\"nvidia-smi\", \"-L\"], stdout=subprocess.PIPE, text=True)\n",
    "N = sum(\n",
    "    line.startswith(\"GPU \") for line in res.stdout.splitlines()\n",
    ")  # number of GPUs\n",
    "if N == 0:\n",
    "    raise RuntimeError(\n",
    "        \"`nvidia-smi -L` reported no GPUs; this example needs at least one\"\n",
    "    )\n",
    "BATCH_SIZE = 128  # batch size per GPU\n",
    "IMAGE_SIZE = 3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create the index file by calling the `tfrecord2idx` script."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from subprocess import check_call\n",
    "import os.path\n",
    "\n",
    "# Make sure the directory that will hold the index file exists.\n",
    "idx_dir = os.path.dirname(tfrecord_idx)\n",
    "if idx_dir:\n",
    "    os.makedirs(idx_dir, exist_ok=True)\n",
    "\n",
    "# Build the index only once; later runs reuse the existing file.\n",
    "if not os.path.isfile(tfrecord_idx):\n",
    "    # check_call raises CalledProcessError if the script exits with a\n",
    "    # non-zero status, so a failed index build is reported here instead\n",
    "    # of surfacing later inside the TFRecord reader.\n",
    "    check_call([tfrecord2idx_script, tfrecord, tfrecord_idx])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us define:\n",
    "- common part of the processing graph, used by all pipelines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nvidia.dali import pipeline_def, Pipeline\n",
    "import nvidia.dali.fn as fn\n",
    "import nvidia.dali.types as types\n",
    "\n",
    "\n",
    "def common_pipeline(jpegs, labels):\n",
    "    \"\"\"Decode, resize and crop/normalize encoded images.\n",
    "\n",
    "    This is the processing shared by all reader pipelines in this\n",
    "    notebook. ``labels`` is passed through unchanged and returned\n",
    "    together with the processed images.\n",
    "    \"\"\"\n",
    "    decoded = fn.decoders.image(jpegs, device=\"mixed\")\n",
    "\n",
    "    # The shorter side is resized to a random length from range (256, 480)\n",
    "    shorter_side = fn.random.uniform(range=(256, 480))\n",
    "    resized = fn.resize(\n",
    "        decoded,\n",
    "        resize_shorter=shorter_side,\n",
    "        interp_type=types.INTERP_LINEAR,\n",
    "    )\n",
    "\n",
    "    # Random position of the 227x227 crop window along each axis\n",
    "    pos_x = fn.random.uniform(range=(0.0, 1.0))\n",
    "    pos_y = fn.random.uniform(range=(0.0, 1.0))\n",
    "    processed = fn.crop_mirror_normalize(\n",
    "        resized,\n",
    "        crop_pos_x=pos_x,\n",
    "        crop_pos_y=pos_y,\n",
    "        dtype=types.FLOAT,\n",
    "        crop=(227, 227),\n",
    "        mean=[128.0, 128.0, 128.0],\n",
    "        std=[1.0, 1.0, 1.0],\n",
    "    )\n",
    "    return processed, labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- MXNet reader pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "@pipeline_def\n",
    "def mxnet_reader_pipeline(num_gpus):\n",
    "    \"\"\"Read the MXNet RecordIO data and run the common processing.\n",
    "\n",
    "    The dataset is split into ``num_gpus`` shards; the shard read is\n",
    "    selected by the device id of the pipeline being built.\n",
    "    \"\"\"\n",
    "    # db_folder ends with a path separator, so the names are appended as-is\n",
    "    record_file = f\"{db_folder}train.rec\"\n",
    "    index_file = f\"{db_folder}train.idx\"\n",
    "    shard = Pipeline.current().device_id\n",
    "\n",
    "    encoded, targets = fn.readers.mxnet(\n",
    "        name=\"Reader\",\n",
    "        path=[record_file],\n",
    "        index_path=[index_file],\n",
    "        random_shuffle=True,\n",
    "        shard_id=shard,\n",
    "        num_shards=num_gpus,\n",
    "    )\n",
    "    return common_pipeline(encoded, targets)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Caffe reader pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "@pipeline_def\n",
    "def caffe_reader_pipeline(num_gpus):\n",
    "    \"\"\"Read the Caffe LMDB data and run the common processing.\n",
    "\n",
    "    The dataset is split into ``num_gpus`` shards; the shard read is\n",
    "    selected by the device id of the pipeline being built.\n",
    "    \"\"\"\n",
    "    shard = Pipeline.current().device_id\n",
    "\n",
    "    encoded, targets = fn.readers.caffe(\n",
    "        name=\"Reader\",\n",
    "        path=lmdb_folder,\n",
    "        random_shuffle=True,\n",
    "        shard_id=shard,\n",
    "        num_shards=num_gpus,\n",
    "    )\n",
    "    return common_pipeline(encoded, targets)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- File reader pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "@pipeline_def\n",
    "def file_reader_pipeline(num_gpus):\n",
    "    \"\"\"Read plain image files from ``image_dir`` and run the common processing.\n",
    "\n",
    "    The dataset is split into ``num_gpus`` shards; the shard read is\n",
    "    selected by the device id of the pipeline being built.\n",
    "    \"\"\"\n",
    "    shard = Pipeline.current().device_id\n",
    "\n",
    "    encoded, targets = fn.readers.file(\n",
    "        name=\"Reader\",\n",
    "        file_root=image_dir,\n",
    "        random_shuffle=True,\n",
    "        shard_id=shard,\n",
    "        num_shards=num_gpus,\n",
    "    )\n",
    "    return common_pipeline(encoded, targets)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- TFRecord reader pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nvidia.dali.tfrecord as tfrec\n",
    "\n",
    "\n",
    "@pipeline_def\n",
    "def tfrecord_reader_pipeline(num_gpus):\n",
    "    \"\"\"Read the TFRecord data and run the common processing.\n",
    "\n",
    "    The dataset is split into ``num_gpus`` shards; the shard read is\n",
    "    selected by the device id of the pipeline being built.\n",
    "    \"\"\"\n",
    "    # Features extracted from every record: encoded image and class label\n",
    "    feature_spec = {\n",
    "        \"image/encoded\": tfrec.FixedLenFeature((), tfrec.string, \"\"),\n",
    "        \"image/class/label\": tfrec.FixedLenFeature([1], tfrec.int64, -1),\n",
    "    }\n",
    "    shard = Pipeline.current().device_id\n",
    "\n",
    "    records = fn.readers.tfrecord(\n",
    "        name=\"Reader\",\n",
    "        path=tfrecord,\n",
    "        index_path=tfrecord_idx,\n",
    "        features=feature_spec,\n",
    "        random_shuffle=True,\n",
    "        shard_id=shard,\n",
    "        num_shards=num_gpus,\n",
    "    )\n",
    "\n",
    "    encoded = records[\"image/encoded\"]\n",
    "    targets = records[\"image/class/label\"]\n",
    "    return common_pipeline(encoded, targets)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us create pipelines and pass them to the Paddle generic iterator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RUN: mxnet_reader_pipeline\n",
      "OK : mxnet_reader_pipeline\n",
      "RUN: caffe_reader_pipeline\n",
      "OK : caffe_reader_pipeline\n",
      "RUN: file_reader_pipeline\n",
      "OK : file_reader_pipeline\n",
      "RUN: tfrecord_reader_pipeline\n",
      "OK : tfrecord_reader_pipeline\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from nvidia.dali.plugin.paddle import DALIGenericIterator\n",
    "\n",
    "\n",
    "# Each entry pairs a pipeline function with the inclusive (min, max)\n",
    "# range its labels are checked against below.\n",
    "pipe_types = [\n",
    "    [mxnet_reader_pipeline, (0, 999)],\n",
    "    [caffe_reader_pipeline, (0, 999)],\n",
    "    [file_reader_pipeline, (0, 1)],\n",
    "    [tfrecord_reader_pipeline, (1, 1000)],\n",
    "]\n",
    "\n",
    "for pipe_fn, label_range in pipe_types:\n",
    "    print(\"RUN: \" + pipe_fn.__name__)\n",
    "    # Build one pipeline per GPU, each with its own device id\n",
    "    pipes = [\n",
    "        pipe_fn(\n",
    "            batch_size=BATCH_SIZE,\n",
    "            num_threads=2,\n",
    "            device_id=device_id,\n",
    "            num_gpus=N,\n",
    "        )\n",
    "        for device_id in range(N)\n",
    "    ]\n",
    "    dali_iter = DALIGenericIterator(\n",
    "        pipes, [\"data\", \"label\"], reader_name=\"Reader\"\n",
    "    )\n",
    "\n",
    "    label_min, label_max = label_range\n",
    "    for data in dali_iter:\n",
    "        # Testing correctness of labels; every element of `data` is a dict\n",
    "        # holding the \"data\" and \"label\" outputs\n",
    "        for d in data:\n",
    "            # Convert once and reuse the array for all checks\n",
    "            label = np.array(d[\"label\"])\n",
    "            # labels need to be integers\n",
    "            assert np.equal(\n",
    "                np.mod(label, 1), 0\n",
    "            ).all(), f\"{pipe_fn.__name__}: non-integer label\"\n",
    "            # labels need to be in range label_range (inclusive)\n",
    "            assert (\n",
    "                label >= label_min\n",
    "            ).all(), f\"{pipe_fn.__name__}: label below {label_min}\"\n",
    "            assert (\n",
    "                label <= label_max\n",
    "            ).all(), f\"{pipe_fn.__name__}: label above {label_max}\"\n",
    "    print(\"OK : \" + pipe_fn.__name__)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
